# -*- coding: utf-8 -*-
"""
Created on Fri Sep 17 10:29:29 2021
用 tf2 写的从任意态制备到 |1> 态的 DQN 算法 —— 本文件加载训练好的网络，并在测试集及噪声环境下评估保真度
@author: Waikikilick
"""

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' #设置 tf 的报警等级
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses,initializers,Sequential,metrics,models
import copy 
from collections import deque
import random
from scipy.linalg import expm
from time import *

# Seed every RNG the script touches (TensorFlow, NumPy, Python's `random`)
# with the same value so runs are reproducible; `random` drives the shuffle
# of the Bloch-sphere point set in env.psi_set.
for _seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    _seed_fn(1)
del _seed_fn

class env(object):
    """Single-qubit state-preparation environment whose target state is |1>.

    One control step evolves the state with ``U = expm(-1j * H * dt)`` where
    ``H = action * sigma_z + 1 * sigma_x``.  States are exchanged with the
    agent as real vectors ``[Re c0, Re c1, Im c0, Im c1]`` because the network
    input cannot be complex (gradient-based back-propagation needs real data).
    """

    def __init__(self,
                 action_space=None,  # allowed actions; defaults to [0, 1]
                 dt=0.1,
                 ):
        """
        Args:
            action_space: list of allowed actions.  Only its length is used in
                this class (``n_actions``); the real values are supplied by the
                caller.  Defaults to a fresh ``[0, 1]`` for every instance.
            dt: duration of one control step.
        """
        # A fresh list per instance instead of one mutable default shared by all.
        self.action_space = [0, 1] if action_space is None else action_space
        self.n_actions = len(self.action_space)
        self.n_features = 4  # length of the real state vector
        self.target_psi = np.asmatrix([[0], [1]], dtype=complex)  # target state |1>
        self.s_x = np.asmatrix([[0, 1], [1, 0]], dtype=complex)
        self.s_z = np.asmatrix([[1, 0], [0, -1]], dtype=complex)
        self.dt = dt
        self.training_set, self.validation_set, self.testing_set = self.psi_set()

        # 41 draws from the standard normal distribution (mean 0, std 1),
        # generated with:
        #     np.random.seed(1); np.random.normal(loc=0.0, scale=1.0, size=41)
        # Used as the per-step profile for "dynamic" noise tests.
        self.noise_normal = np.array([ 1.62434536, -0.61175641, -0.52817175, -1.07296862,  0.86540763,
       -2.3015387 ,  1.74481176, -0.7612069 ,  0.3190391 , -0.24937038,
        1.46210794, -2.06014071, -0.3224172 , -0.38405435,  1.13376944,
       -1.09989127, -0.17242821, -0.87785842,  0.04221375,  0.58281521,
       -1.10061918,  1.14472371,  0.90159072,  0.50249434,  0.90085595,
       -0.68372786, -0.12289023, -0.93576943, -0.26788808,  0.53035547,
       -0.69166075, -0.39675353, -0.6871727 , -0.84520564, -0.67124613,
       -0.0126646 , -1.11731035,  0.2344157 ,  1.65980218,  0.74204416,
       -0.19183555])

        # Constant per-step profile used for "static" noise tests.
        self.noise_constant = np.ones(41)

        # Per-step noise read by step_noise_J / step_noise_h.  Callers overwrite
        # it with ``amplitude * profile``; until then it is zero, so the noisy
        # steps reproduce the noise-free dynamics instead of raising
        # AttributeError.
        self.noise = np.zeros_like(self.noise_normal)

    def psi_set(self):
        """Build a grid of pure states on the Bloch sphere and split it.

        Returns:
            (training_set, validation_set, testing_set): lists of 2x1 complex
            matrices holding 32, 32 and 64 states respectively.
        """
        theta_num = 6    # polar angles, excluding the poles 0 and pi
        varphi_num = 21  # azimuthal points on one circle
        # Total points: theta_num * varphi_num + 2 poles = 6 * 21 + 2 = 128

        theta = np.delete(np.linspace(0, np.pi, theta_num + 1, endpoint=False), [0])
        varphi = np.linspace(0, np.pi * 2, varphi_num, endpoint=False)

        psi_set = [
            np.asmatrix([[np.cos(th / 2)],
                         [np.sin(th / 2) * (np.cos(ph) + np.sin(ph) * (0 + 1j))]])
            for th in theta
            for ph in varphi
        ]
        psi_set.append(np.asmatrix([[1], [0]], dtype=complex))  # pole |0>
        psi_set.append(np.asmatrix([[0], [1]], dtype=complex))  # pole |1>
        random.shuffle(psi_set)  # uses Python's global `random` state

        training_set = psi_set[0:32]
        validation_set = psi_set[32:64]
        testing_set = psi_set[64:128]

        return training_set, validation_set, testing_set

    def reset(self, init_psi):
        """Start an episode at ``init_psi``; return its real state vector.

        Returns:
            np.ndarray ``[Re c0, Re c1, Im c0, Im c1]``.
        """
        return self._psi_to_state(init_psi)

    @staticmethod
    def _state_to_psi(state):
        """Convert ``[Re c0, Re c1, Im c0, Im c1]`` into a 2x1 complex matrix."""
        state = np.asarray(state, dtype=float)
        half = len(state) // 2
        return np.asmatrix((state[:half] + state[half:] * 1j).reshape(-1, 1))

    @staticmethod
    def _psi_to_state(psi):
        """Convert a complex column vector into ``[Re..., Im...]`` (float array)."""
        flat = np.asarray(psi).ravel()
        return np.concatenate([flat.real, flat.imag])

    def _fidelity(self, psi):
        """Return |<psi|target>|^2 as a Python float."""
        return (np.abs(psi.H * self.target_psi) ** 2).item(0).real

    def _evolve(self, state, z_coeff, x_coeff):
        """Evolve ``state`` by one step of H = z_coeff*sigma_z + x_coeff*sigma_x.

        Returns:
            The evolved state as a 2x1 complex matrix.
        """
        psi = self._state_to_psi(state)
        H = z_coeff * self.s_z + x_coeff * self.s_x
        U = expm(-1j * H * self.dt)
        return U * psi  # psi is an np.matrix, so `*` is a matrix product

    def step(self, state, action, nstep):
        """Take one noise-free step.

        NOTE(review): ``action`` is used directly as the sigma_z coefficient
        (it is not looked up in ``action_space``); that matches the
        ``action_space=list(range(n))`` used by this script.

        Returns:
            (next_state, reward, done, fid).  The reward equals the fidelity;
            ``done`` is set once ``1 - fid < 1e-3`` or ``nstep >= 2*pi/dt``.
        """
        psi = self._evolve(state, float(action), 1)

        err = 1e-3  # same value as the original literal 10e-4
        fid = self._fidelity(psi)
        rwd = fid

        done = (((1 - fid) < err) or nstep >= 2 * np.pi / self.dt)

        return self._psi_to_state(psi), rwd, done, fid

    def step_noise_J(self, state, action, nstep):
        """One step with additive noise on the sigma_z (control) coefficient.

        H = (action + self.noise[nstep]) * sigma_z + sigma_x

        Returns:
            (next_state, fid)
        """
        psi = self._evolve(state, float(action) + self.noise[nstep], 1)
        return self._psi_to_state(psi), self._fidelity(psi)

    def step_noise_h(self, state, action, nstep):
        """One step with noise on the sigma_x coefficient.

        H = action * sigma_z + (1 + self.noise[nstep]) * sigma_x

        Returns:
            (next_state, fid)
        """
        psi = self._evolve(state, float(action), 1 + self.noise[nstep])
        return self._psi_to_state(psi), self._fidelity(psi)
    
def relu(vector):
    """Clamp negative entries of ``vector`` to zero in place and return it.

    The loop runs over the first axis, so for a 1 x n ``np.matrix`` the single
    row is clamped in one ``np.maximum`` call.  The object passed in is
    modified and returned (not a copy).
    """
    for idx, item in enumerate(vector):
        vector[idx] = np.maximum(0, item)
    return vector

def _as_numpy(value):
    """Return ``value`` as a NumPy array; TensorFlow variables go through ``.numpy()``."""
    return value.numpy() if hasattr(value, 'numpy') else np.asarray(value)


def agent_matrix(x, weights=None):
    """Pick the greedy action for observation ``x`` with a NumPy forward pass.

    Replays the network's dense layers outside TensorFlow: for each
    (kernel, bias) pair, ``out = relu(out @ kernel + bias)``, then returns the
    index of the largest output.

    Args:
        x: real observation vector (as returned by ``env.reset`` / ``env.step``).
        weights: flat sequence ``[kernel_0, bias_0, kernel_1, bias_1, ...]`` of
            TensorFlow variables or array-likes.  Defaults to the module-level
            ``net`` (``network.trainable_variables`` set in ``__main__``).
            Every pair is used, so networks of any depth work.

    Returns:
        The index of the maximal output, i.e. the chosen action.

    Raises:
        ValueError: if ``weights`` is empty or not made of (kernel, bias) pairs.
    """
    if weights is None:
        weights = net
    if len(weights) == 0 or len(weights) % 2:
        raise ValueError('weights must be a flat list of (kernel, bias) pairs')

    # Plain ndarrays with `@` instead of np.mat, which was removed in NumPy 2.0.
    out = np.atleast_2d(np.asarray(x))
    for kernel, bias in zip(weights[0::2], weights[1::2]):
        # NOTE(review): ReLU is applied after the output layer too (as before).
        # If the saved Keras model's last layer is linear, negative Q-values are
        # clipped to 0 here and ties resolve to the lowest index; confirm
        # against the model architecture.
        out = np.maximum(0, out @ _as_numpy(kernel) + _as_numpy(bias))
    return np.argmax(out)


def testing():  # evaluate the trained agent on the test set (noise-free)
    """Return the best fidelity reached in each noise-free test episode.

    For every state in ``env.testing_set`` the greedy policy (``agent_matrix``)
    is rolled out with ``env.step`` until ``done``.  The maximum fidelity seen
    during that episode is recorded.
    """
    print('\n测试中, 请稍等...')

    best_fids = []
    for init_psi in env.testing_set:
        obs = env.reset(init_psi)
        step_idx = 0
        best = 0
        done = False
        while not done:
            act = agent_matrix(obs)
            obs, _, done, fid = env.step(obs, act, step_idx)
            step_idx += 1
            best = max(best, fid)
        best_fids.append(best)
    return best_fids

def _greedy_actions_until_peak(init_psi):
    """Roll out the greedy policy noise-free from ``init_psi``.

    Returns the actions taken up to and including the step that first reached
    the episode's maximum fidelity.  The actions after that point are not
    replayed under noise.
    """
    observation = env.reset(init_psi)
    nstep = 0
    action_list = []
    fid_list = []
    while True:
        action = agent_matrix(observation)
        observation, _, done, fid = env.step(observation, action, nstep)
        action_list.append(action)
        fid_list.append(fid)
        nstep += 1
        if done:
            break
    max_index = fid_list.index(max(fid_list))  # first occurrence of the peak
    return action_list[:max_index + 1]


def _replay_with_noise(init_psi, actions, noisy_step):
    """Replay ``actions`` from ``init_psi`` through ``noisy_step``.

    Returns:
        The fidelity after the last action, which is taken as the episode's
        fidelity under noise.
    """
    observation = env.reset(init_psi)
    fid = None
    for nstep, action in enumerate(actions):
        observation, fid = noisy_step(observation, action, nstep)
    return fid


def _noisy_testing(noise_a, noise_profile, noisy_step):
    """Run the noisy-replay evaluation on every test state.

    Sets ``env.noise = noise_a * noise_profile``.  Then, for each state in
    ``env.testing_set``, it replays that state's noise-free greedy action
    sequence (truncated at peak fidelity) with ``noisy_step``.

    Returns:
        List with one final fidelity per test state.
    """
    env.noise = noise_a * noise_profile
    return [
        _replay_with_noise(init_psi, _greedy_actions_until_peak(init_psi), noisy_step)
        for init_psi in env.testing_set
    ]


def testing_noise_J_dynamic(noise_a):  # test under dynamic (time-varying) J noise
    """Test-set fidelities with per-step random noise on the sigma_z coefficient.

    The noise at step ``nstep`` is ``noise_a * env.noise_normal[nstep]``, added
    to the action in ``env.step_noise_J``.

    Args:
        noise_a: noise amplitude.

    Returns:
        List with one final fidelity per state in ``env.testing_set``.
    """
    print('\n测试 J 动态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)
    return _noisy_testing(noise_a, env.noise_normal, env.step_noise_J)


def testing_noise_J_static(noise_a):  # test under static (constant) J noise
    """Test-set fidelities with a constant offset on the sigma_z coefficient.

    The noise at every step is ``noise_a`` (``noise_a * env.noise_constant``),
    added to the action in ``env.step_noise_J``.

    Args:
        noise_a: noise amplitude.

    Returns:
        List with one final fidelity per state in ``env.testing_set``.
    """
    print('\n测试 J 静态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)
    return _noisy_testing(noise_a, env.noise_constant, env.step_noise_J)


def testing_noise_h_dynamic(noise_a):  # test under dynamic (time-varying) h noise
    """Test-set fidelities with per-step random noise on the sigma_x coefficient.

    ``env.step_noise_h`` uses ``1 + noise_a * env.noise_normal[nstep]`` as the
    sigma_x coefficient.

    Args:
        noise_a: noise amplitude.

    Returns:
        List with one final fidelity per state in ``env.testing_set``.
    """
    print('\n测试 h 动态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)
    return _noisy_testing(noise_a, env.noise_normal, env.step_noise_h)


def testing_noise_h_static(noise_a):  # test under static (constant) h noise
    """Test-set fidelities with a constant offset on the sigma_x coefficient.

    ``env.step_noise_h`` uses ``1 + noise_a`` as the sigma_x coefficient at
    every step.

    Args:
        noise_a: noise amplitude.

    Returns:
        List with one final fidelity per state in ``env.testing_set``.
    """
    print('\n测试 h 静态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)
    return _noisy_testing(noise_a, env.noise_constant, env.step_noise_h)
                 
    
if __name__ == "__main__":

    dt = np.pi / 10

    # Load the trained network; agent_matrix reads its weights through the
    # module-level name `net`.
    network = keras.models.load_model('dqn_tf2_single_qubit_to_1_saved_model')
    net = network.trainable_variables

    # NOTE: this rebinds `env` from the class to the instance.  The module-level
    # test functions read the global `env` and expect the instance.
    env = env(action_space=list(range(4)),  # 4 actions: 0 .. 3
              dt=dt)

    testing_set = env.testing_set  # global read by the noise test functions

    # Noise-free evaluation.
    testing_fid_list = testing()
    print('测试集平均保真度为：', np.mean(testing_fid_list))

    # Effect of noise on the fidelity.
    # The static-noise numbers recorded below (J_static / h_static) came from
    # sweeping noise_a over [-0.1, -0.09, ..., 0.09, 0.1] (step 0.01) with
    #     np.mean(testing_noise_J_static(noise_a))
    # or  np.mean(testing_noise_h_static(noise_a)).
    # For J dynamic noise, call testing_noise_J_dynamic in the loop instead.
    noise_amplitudes = [0, 0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.2]
    fids_noise_list = []
    for noise_a in noise_amplitudes:
        test_noise_mean = np.mean(testing_noise_h_dynamic(noise_a))
        print('噪声下的平均保真度为：', test_noise_mean)
        fids_noise_list.append(test_noise_mean)

    print('测试集在噪声下平均保真度为', fids_noise_list)
    
# J_static: [0.9794733402807905, 0.9820849282426767, 0.9845052989593666, 0.9867231627334735, 0.9887267641616047, 0.9905039384875381, 0.992042180019179, 0.9933287220702951, 0.9943506277038209, 0.9950948903779574, 0.9955485434320497, 0.9956987771985925, 0.9955330623930708, 0.9950392783167978, 0.994205844311081, 0.9930218528256434, 0.9914772024111753, 0.9895627289165626, 0.9872703331654761, 0.9845931034057308, 0.9815254308670808]
#           [0.97947334 0.98208493 0.9845053  0.98672316 0.98872676 0.99050394 0.99204218 0.99332872 0.99435063 0.99509489 0.99554854 0.99569878 0.99553306 0.99503928 0.99420584 0.99302185 0.9914772  0.98956273 0.98727033 0.9845931  0.98152543]

# h_static: [0.982549073079168, 0.9847851255248172, 0.9868302790335894, 0.9886781102081892, 0.9903221752078595, 0.9917560265100422, 0.9929732309890851, 0.9939673893016825, 0.9947321565594855, 0.9952612642594008, 0.9955485434320497, 0.9955879489582835, 0.9953735849927705, 0.9948997314228943, 0.9941608712796968, 0.9931517190067585, 0.9918672494815173, 0.990302727672548, 0.9884537388055373, 0.9863162189001062, 0.9838864855295902]
#           [0.98254907 0.98478513 0.98683028 0.98867811 0.99032218 0.99175603 0.99297323 0.99396739 0.99473216 0.99526126 0.99554854 0.99558795 0.99537358 0.99489973 0.99416087 0.99315172 0.99186725 0.99030273 0.98845374 0.98631622 0.98388649]

# J_dynamic: [0.9955485434320497, 0.995553035008053, 0.9953816474960144, 0.9950311192817811, 0.9944984094289739, 0.993780713565465, 0.9928754794401528, 0.9917804220702768, 0.9904935383981217, 0.9890131213750053, 0.98733777338989]
             # [0.99554854 0.99555304 0.99538165 0.99503112 0.99449841 0.99378071 0.99287548 0.99178042 0.99049354 0.98901312 0.98733777]
             
# h_dynamic: [0.9955485434320497, 0.9949927376494851, 0.9938545857191387, 0.9921271183936076, 0.9898053814654768, 0.9868864954983084, 0.9833697038779925, 0.9792564086539373, 0.9745501937352331, 0.969256835108737, 0.9633842978528595]
             # [0.99554854 0.99499274 0.99385459 0.99212712 0.98980538 0.9868865 0.9833697  0.97925641 0.97455019 0.96925684 0.9633843 ]